import torch
import torch.nn as nn
import torchvision.transforms as transforms
from adversarialbox.utils import to_var, test
import torchvision
from setbitnumber import setBitNumber
from hamming import solve
import numpy as np
from tensorboardX import SummaryWriter
from layers_br_test import * 
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # number GPUs by PCI bus id
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process

# parameters #################################################################################
mode='CFTBR'  # 'CFT' / 'CFTBR' run 60 attack rounds in main(); other modes run a single round
TRAIN=1  # truthy: generate the trigger and fine-tune; falsy: load saved trigger/weights from logdir and evaluate
PAGE_CHECK=True  # forwarded to train()
GLOBAL=True  # forwarded to train()
Nflip=1  # bit-flip count; also part of the log directory name
targets=2  # target class index: every NGR label is overwritten with it
# The trigger patch covers rows/cols start:end of each image. end must be
# greater than start, otherwise the slice is empty and no trigger pixel is
# ever written. 21:31 is a 10x10 patch in the bottom-right corner of a
# 32x32 CIFAR-10 image.
start=21
end=31
if not 0 <= start < end:
	raise ValueError('trigger window start:end must be non-empty, got %d:%d' % (start, end))
high=100  # value the selected output neurons are pushed to during trigger generation
# NOTE(review): absolute path at the filesystem root -- confirm this is not
# meant to be the relative './experiments_resnet18/'.
logdir= '/experiments_resnet18/'+mode+'/Nflip='+str(Nflip)+'/'
if TRAIN:
	writer = SummaryWriter(logdir=logdir)
# Hyper-parameters
param = {
	'batch_size': 256,
	'test_batch_size': 256,
	'num_epochs':250,
	'delay': 251,
	'learning_rate': 0.001,
	'weight_decay': 1e-6,
}
inf_with_weight = False  # disabled by default
# 8-bit quantisation: 256 levels, half_lvls = (256 - 2) / 2 = 127.0
N_bits = 8
full_lvls = 2**N_bits
half_lvls = (full_lvls - 2) / 2

# Per-channel normalisation statistics on a 0-1 scale, given to Normalize_layer.
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]

####################################################################

def _with_normalization(model):
	"""Return ``Sequential(Normalize_layer(mean, std), model)`` using the module-level mean/std."""
	return torch.nn.Sequential(
		Normalize_layer(mean, std),
		model,
	)


def _zero_weight_grads(model):
	"""Zero, in place, the gradient of every submodule ``weight`` that has one.

	Only ``weight`` gradients are reset; gradients of other parameters
	(e.g. biases) are left untouched.
	"""
	for m in model.modules():
		if hasattr(m, 'weight'):
			if m.weight.grad is not None:
				m.weight.grad.data.zero_()


def _trigger_path(channel):
	"""Return the text-file path under ``logdir`` for trigger channel *channel* (0-based)."""
	return logdir + 'trojan_last_layer_img' + str(channel + 1) + '.txt'


def _save_trigger(x_tri):
	"""Write channels 0-2 of the first image of the 4-D tensor *x_tri* to text files."""
	for c in range(3):
		np.savetxt(_trigger_path(c), x_tri[0, c, :, :].detach().cpu().numpy(), fmt='%f')


def _load_trigger(template):
	"""Overwrite channels 0-2 of the 3-D tensor *template* in place with the saved trigger and return it."""
	for c in range(3):
		template[c, :, :] = torch.from_numpy(np.loadtxt(_trigger_path(c), dtype=float))
	return template


def main():
	"""Run the trojan bit-flip experiment configured by the module-level parameters.

	Loads the CIFAR-10 test set and three normalised models that all start from
	the pretrained 8-bit checkpoint:
	``net`` (the model being modified), ``net1`` (reported as the clean model)
	and ``net2`` (a ``ResNet188`` used to synthesise the trigger).

	* ``TRAIN`` truthy: for 60 rounds ('CFT'/'CFTBR') or 1 round (other modes),
	  back-propagate a target-labelled test batch, pick the weight(s) of
	  parameter ``last_layer_idx`` with the largest gradient magnitude for the
	  target class, optimise a trigger patch with FGSM on ``net2`` so the
	  selected outputs reach ``high``, save it under ``logdir`` and call
	  ``train``.
	* ``TRAIN`` falsy: load the saved trojaned weights and trigger, evaluate
	  them and apply ``bit_reduction_test``.

	Both paths end by evaluating the clean and the trojaned model.
	"""
	print('==> Preparing data..')
	transform_test = transforms.Compose([
		transforms.ToTensor()
	])

	testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
	loader_test = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

	# Construction order (net, net1, net2) matches the original so random
	# initialisation of any parameters missing from the checkpoint is unchanged.
	net = _with_normalization(ResNet18())
	net1 = _with_normalization(ResNet18())
	net2 = _with_normalization(ResNet188())

	pretrained_path = 'pretrained_models/Resnet18_8bit.pkl'
	# strict=False: checkpoint keys that do not match the model are skipped.
	net.load_state_dict(torch.load(pretrained_path), strict=False)
	if not TRAIN:
		# Overlay the saved trojaned weights on top of the pretrained ones.
		net.load_state_dict(torch.load(logdir + 'Resnet18_8bit_all_layers_trojan.pkl'), strict=False)
	net.train()
	net = net.cuda()
	net2.load_state_dict(torch.load(pretrained_path), strict=False)
	net2 = net2.cuda()
	net1.load_state_dict(torch.load(pretrained_path), strict=False)
	net1 = net1.cuda()

	if torch.cuda.is_available():
		print('CUDA enabled.')

	criterion = nn.CrossEntropyLoss()
	criterion = criterion.cuda()

	# ----- NGR step -----
	# Back-propagate one test batch (128 images), relabelled as the target
	# class, to identify the target neurons.
	data, target = next(iter(loader_test))
	data, target = data.cuda(), target.cuda()
	mins, maxs = data.min(), data.max()

	x_tri = torch.zeros_like(data)
	x_tri[:, 0:3, start:end, start:end] += 255
	# Kept separate from the trigger tensors below so that every round scores
	# the same batch against the target label.
	x_ngr, y_ngr = to_var(data), to_var(target.long())
	y_ngr[:] = targets

	net.eval()

	best_loss = 999
	iterative = mode in ('CFTBR', 'CFT')
	num_iter = 60 if iterative else 1
	if TRAIN:
		last_layer_idx = 62
		train_nflip = 100  # flip budget handed to train()
		# Batch-size-1 loader for the Attack object; loader_test keeps its
		# 128-image batches for the final evaluation.
		loader_single = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)
		for n in range(num_iter):
			output = net(x_ngr)
			loss = criterion(output, y_ngr)
			_zero_weight_grads(net)
			loss.backward()

			weight = list(net.parameters())[last_layer_idx]
			# Iterative modes retarget the single highest-gradient weight each
			# round; other modes take the top Nflip (module-level) at once.
			_, w_id = weight.grad.detach().abs().topk(1 if iterative else Nflip)
			tar = w_id[targets]

			# ----- Trigger generation -----
			# Start from an all-zero single image whose patch holds the current
			# trigger (iterative modes) or the constant 0.5 (other modes).
			x_var = to_var(torch.zeros_like(data[:1]))
			if iterative:
				x_var[:, 0:3, start:end, start:end] = x_tri[0, 0:3, start:end, start:end]
			else:
				x_var[:, 0:3, start:end, start:end] = 0.5

			y = net2(x_var)  # initial target values for trigger generation
			y[:, tar] = high  # push the selected output neurons up to `high`

			model_attack = Attack(dataloader=loader_single,
									attack_method='fgsm', epsilon=0.001)

			# 4 step sizes x 200 FGSM updates each.
			for ep in [0.5, 0.1, 0.01, 0.001]:
				for _ in range(200):
					x_tri = model_attack.attack_method(
								net2, x_var.cuda(), y, tar, ep, start, end, mins, maxs)
					x_var = x_tri

			# Save the trigger image channels for later evaluation runs.
			_save_trigger(x_tri)
			best_loss = train(n, net, net1, train_nflip, last_layer_idx, testset, criterion,
							x_tri, start, end, targets, best_loss, writer, logdir,
							PAGE_CHECK, GLOBAL, TRAIN, mode, tar)
			if iterative:
				# Keep the trigger generator in sync with the updated model.
				net2.load_state_dict(net.state_dict(), strict=False)

	else:
		x_tri = _load_trigger(data.clone().data[0, :, :, :])
		output = net(x_ngr)
		loss = criterion(output, y_ngr)
		_zero_weight_grads(net)
		# NOTE(review): gradients are not read in this function; presumably
		# bit_reduction_test relies on them -- confirm.
		loss.backward()
		print("Fine Tuned model:")
		test1(net, loader_test, x_tri, start, end, targets, TRAIN)
		test(net, loader_test)
		test_nflip = 11
		net = bit_reduction_test(net, net1, test_nflip, targets, count=True)
		# TODO: re-enable the final bit reduction
		# (bit_reduction_test(net, net1, Nflip, targets)); it used an unlimited
		# flip budget when mode == 'BadNet'.

	print("Clean model:")
	test1(net1, loader_test, x_tri, start, end, targets, TRAIN)
	test(net1, loader_test)
	print("Trojanad model:")
	test1(net, loader_test, x_tri, start, end, targets, TRAIN)
	test(net, loader_test)
	if TRAIN:
		writer.close()


if __name__ == "__main__":
	# Run the experiment only when executed as a script, not when imported.
	main()
	